{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rA5Mubike7OJ"
   },
   "source": [
    "##### Copyright 2020 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:07.314751Z",
     "iopub.status.busy": "2023-11-07T23:32:07.314500Z",
     "iopub.status.idle": "2023-11-07T23:32:07.318521Z",
     "shell.execute_reply": "2023-11-07T23:32:07.317969Z"
    },
    "id": "fY0a3LRYfHUl"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iNz7xXMSsAQa"
   },
   "source": [
    "# 使用 ParameterServerStrategy 进行参数服务器训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6v4D6QfcfTrm"
   },
   "source": [
    "## 概述\n",
    "\n",
    "[参数服务器训练](https://www.usenix.org/system/files/conference/osdi14/osdi14-paper-li_mu.pdf)是一种常见的数据并行方法，用于在多台机器上扩展模型训练。\n",
    "\n",
    "参数服务器训练集群由*工作进程*和*参数服务器*组成。变量在参数服务器上创建，并在每个步骤中由工作进程读取和更新。默认情况下，工作进程会独立读取和更新这些变量，而不会彼此同步。因此，参数服务器式训练有时也称为*异步训练*。\n",
    "\n",
    "在 TensorFlow 2 中，参数服务器训练由 `tf.distribute.ParameterServerStrategy` 类提供支持，该类会将训练步骤分发到一个集群，该集群可扩展到数千个工作进程（伴随着参数服务器）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W1LGfTdgOF-J"
   },
   "source": [
    "### 支持的训练方法\n",
    "\n",
    "支持的训练方法主要有两种：\n",
    "\n",
    "- Keras `Model.fit` API：如果您更喜欢高级抽象和训练处理。如果您正在训练 `tf.keras.Model`，通常建议使用此方法。\n",
    "- 自定义训练循环：如果您更喜欢定义训练循环的详细信息（有关详细信息，请参阅[自定义训练](../customization/custom_training_walkthrough.ipynb)、[从头开始编写训练循环](https://tensorflow.google.cn/guide/keras/writing_a_training_loop_from_scratch)和[使用 Keras 和 MultiWorkerMirroredStrategy 自定义训练循环](multi_worker_with_ctl.ipynb)的指南）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FjbULGvV7NRz"
   },
   "source": [
    "### 具有作业和任务的集群\n",
    "\n",
    "无论选择哪种 API（`Model.fit` 或自定义训练循环），TensorFlow 2 中的分布式训练都涉及一个包含多个 `'jobs'` 的 `'cluster'`，每个作业可能有一个或多个 `'tasks'`。\n",
    "\n",
    "在使用参数服务器训练时，建议具有：\n",
    "\n",
    "- 一个*协调器*作业（作业名称为 `chief`）\n",
    "- 多个*工作进程*作业（作业名称为 `worker`）\n",
    "- 多个*参数服务器*作业（作业名称为 `ps`）\n",
    "\n",
    "*协调器*会创建资源、调度训练任务、编写检查点并处理任务失败。*工作进程*和*参数服务器*会运行 `tf.distribute.Server` 实例来监听来自协调器的请求。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oLV1FbpLtqtB"
   },
   "source": [
    "### 使用 `Model.fit` API 进行参数服务器训练\n",
    "\n",
    "使用 `Model.fit` API 训练参数服务器需要协调器使用 `tf.distribute.ParameterServerStrategy` 对象。与没有策略或具有其他策略的 `Model.fit` 用法类似，工作流包括创建和编译模型、准备回调和调用 `Model.fit`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yJ5AosxFyfzk"
   },
   "source": [
    "### 使用自定义训练循环进行参数服务器训练\n",
    "\n",
    "对于自定义训练循环，`tf.distribute.coordinator.ClusterCoordinator` 类是用于协调器的关键组件。\n",
    "\n",
    "- `ClusterCoordinator` 类需要与 `tf.distribute.ParameterServerStrategy` 对象结合使用。\n",
    "- 此 `tf.distribute.Strategy` 对象用于提供集群的信息并用于定义训练步骤，如[使用 tf.distribute.Strategy 进行自定义训练](custom_training.ipynb)中所示。\n",
    "- 之后，`ClusterCoordinator` 对象会将这些训练步骤的执行分派给远程工作进程。\n",
    "\n",
    "`ClusterCoordinator` 对象提供的最重要的 API 是 `schedule`：\n",
    "\n",
    "- `schedule` API 会将 `tf.function` 排入队列并立即返回一个类似未来的 `RemoteValue` 。\n",
    "- 排队的函数将被分派给后台线程中的远程工作进程，并且它们的 `RemoteValue` 将被异步填充。\n",
    "- 由于 `schedule` 不需要分配工作进程，因此传入的 `tf.function` 可以在任何可用的工作进程上执行。\n",
    "- 如果执行它的工作进程在完成之前变得不可用，则该函数将在另一个可用的工作进程上重试。\n",
    "- 由于上述事实以及函数执行非原子方式的事实，单个函数调用可能会执行多次。\n",
    "\n",
    "除了调度远程函数外，`ClusterCoordinator` 还会帮助在所有工作进程上创建数据集，并在工作进程从故障中恢复后重建这些数据集。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MyDnWjmOje5-"
   },
   "source": [
    "## 教程设置\n",
    "\n",
    "本教程将分支到 `Model.fit` 和自定义训练循环路径，您可以选择适合您需求的路径。“使用 X 进行训练”以外的部分两种路径均适用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:07.322677Z",
     "iopub.status.busy": "2023-11-07T23:32:07.322468Z",
     "iopub.status.idle": "2023-11-07T23:32:09.410904Z",
     "shell.execute_reply": "2023-11-07T23:32:09.409908Z"
    },
    "id": "0-V3LUcIs4a-"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting portpicker\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Downloading portpicker-1.6.0-py3-none-any.whl.metadata (1.5 kB)\r\n",
      "Requirement already satisfied: psutil in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from portpicker) (5.9.6)\r\n",
      "Downloading portpicker-1.6.0-py3-none-any.whl (16 kB)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing collected packages: portpicker\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully installed portpicker-1.6.0\r\n"
     ]
    }
   ],
   "source": [
    "!pip install portpicker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:09.415195Z",
     "iopub.status.busy": "2023-11-07T23:32:09.414923Z",
     "iopub.status.idle": "2023-11-07T23:32:11.893749Z",
     "shell.execute_reply": "2023-11-07T23:32:11.892958Z"
    },
    "id": "GlI_NAVFae3J"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-11-07 23:32:09.863884: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-11-07 23:32:09.863934: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-11-07 23:32:09.865658: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    }
   ],
   "source": [
    "#@title\n",
    "import multiprocessing\n",
    "import os\n",
    "import random\n",
    "import portpicker\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uvwgM2rzgzIC"
   },
   "source": [
    "## 集群设置\n",
    "\n",
    "如上所述，参数服务器训练集群需要一个运行您的训练程序的协调器任务、一个或多个工作进程和运行 TensorFlow 服务器的参数服务器任务 `tf.distribute.Server`，以及可能运行边车评估的附加评估任务（请参阅下面的[边车评估部分](#sidecar_evaluation)）。设置它们的要求是：\n",
    "\n",
    "- 协调器任务需要知道除评估器之外的所有其他 TensorFlow 服务器的地址和端口。\n",
    "- 工作进程和参数服务器需要知道他们需要监听哪个端口。为简单起见，在这些任务上创建 TensorFlow 服务器时，通常可以传入完整的集群信息。\n",
    "- 评估器任务不必知道训练集群的设置。如果知道，则它不应该尝试连接到训练集群。\n",
    "- 工作进程和参数服务器的任务类型应该分别为 `\"worker\"` 和 `\"ps\"`。出于遗留原因，协调器应使用 `\"chief\"` 作为任务类型。\n",
    "\n",
    "在本教程中，您将创建一个进程内集群，以便整个参数服务器训练可以在 Colab 中运行。您将在后面的部分中学习如何设置[真正的集群](#real_clusters)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7UNs7Lm2g19n"
   },
   "source": [
    "### 进程内集群\n",
    "\n",
    "您将首先创建多个 TensorFlow 服务器，稍后再连接它们。请注意，这仅用于本教程的演示目的，在实际训练中，服务器将在 `\"worker\"` 和 `\"ps\"` 机器上启动。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:11.898593Z",
     "iopub.status.busy": "2023-11-07T23:32:11.898184Z",
     "iopub.status.idle": "2023-11-07T23:32:14.001218Z",
     "shell.execute_reply": "2023-11-07T23:32:14.000053Z"
    },
    "id": "FbrP5pXuaoVH"
   },
   "outputs": [],
   "source": [
    "def create_in_process_cluster(num_workers, num_ps):\n",
    "  \"\"\"Creates and starts local servers and returns the cluster_resolver.\"\"\"\n",
    "  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]\n",
    "  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]\n",
    "\n",
    "  cluster_dict = {}\n",
    "  cluster_dict[\"worker\"] = [\"localhost:%s\" % port for port in worker_ports]\n",
    "  if num_ps > 0:\n",
    "    cluster_dict[\"ps\"] = [\"localhost:%s\" % port for port in ps_ports]\n",
    "\n",
    "  cluster_spec = tf.train.ClusterSpec(cluster_dict)\n",
    "\n",
    "  # Workers need some inter_ops threads to work properly.\n",
    "  worker_config = tf.compat.v1.ConfigProto()\n",
    "  if multiprocessing.cpu_count() < num_workers + 1:\n",
    "    worker_config.inter_op_parallelism_threads = num_workers + 1\n",
    "\n",
    "  for i in range(num_workers):\n",
    "    tf.distribute.Server(\n",
    "        cluster_spec,\n",
    "        job_name=\"worker\",\n",
    "        task_index=i,\n",
    "        config=worker_config,\n",
    "        protocol=\"grpc\")\n",
    "\n",
    "  for i in range(num_ps):\n",
    "    tf.distribute.Server(\n",
    "        cluster_spec,\n",
    "        job_name=\"ps\",\n",
    "        task_index=i,\n",
    "        protocol=\"grpc\")\n",
    "\n",
    "  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(\n",
    "      cluster_spec, rpc_layer=\"grpc\")\n",
    "  return cluster_resolver\n",
    "\n",
    "# Set the environment variable to allow reporting worker and ps failure to the\n",
    "# coordinator. This is a workaround and won't be necessary in the future.\n",
    "os.environ[\"GRPC_FAIL_FAST\"] = \"use_caller\"\n",
    "\n",
    "NUM_WORKERS = 3\n",
    "NUM_PS = 2\n",
    "cluster_resolver = create_in_process_cluster(NUM_WORKERS, NUM_PS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pX_91OByt0J2"
   },
   "source": [
    "进程内集群设置经常用于单元测试，例如[这里](https://github.com/tensorflow/tensorflow/blob/eb4c40fc91da260199fa2aed6fe67d36ad49fafd/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py#L447)。\n",
    "\n",
    "本地测试的另一个选择是在本地机器上启动进程，请查看[使用 Keras 进行多工作进程训练](multi_worker_with_keras.ipynb)以获取这种方式的示例。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zyby6M2Jqg6J"
   },
   "source": [
    "## 实例化 ParameterServerStrategy\n",
    "\n",
    "在深入了解训练代码之前，我们来实例化一个 `tf.distribute.ParameterServerStrategy` 对象。请注意，无论您是使用 `Model.fit` 还是自定义训练循环，都需要这样做。`variable_partitioner` 参数将在[变量分片部分](#variable_sharding)进行解释。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:14.006979Z",
     "iopub.status.busy": "2023-11-07T23:32:14.006700Z",
     "iopub.status.idle": "2023-11-07T23:32:14.135732Z",
     "shell.execute_reply": "2023-11-07T23:32:14.134912Z"
    },
    "id": "_YyEPgisrC35"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:`tf.distribute.experimental.ParameterServerStrategy` is initialized with cluster_spec: ClusterSpec({'ps': ['localhost:38883', 'localhost:43499'], 'worker': ['localhost:38671', 'localhost:45527', 'localhost:38193']})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:ParameterServerStrategyV2 is now connecting to cluster with cluster_spec: ClusterSpec({'ps': ['localhost:38883', 'localhost:43499'], 'worker': ['localhost:38671', 'localhost:45527', 'localhost:38193']})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:chief/replica:0/task:0/device:GPU:0', '/job:chief/replica:0/task:0/device:GPU:1', '/job:chief/replica:0/task:0/device:GPU:2', '/job:chief/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:chief/replica:0/task:0/device:GPU:0', '/job:chief/replica:0/task:0/device:GPU:1', '/job:chief/replica:0/task:0/device:GPU:2', '/job:chief/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Number of GPUs on workers: 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Number of GPUs on workers: 4\n"
     ]
    }
   ],
   "source": [
    "variable_partitioner = (\n",
    "    tf.distribute.experimental.partitioners.MinSizePartitioner(\n",
    "        min_shard_bytes=(256 << 10),\n",
    "        max_shards=NUM_PS))\n",
    "\n",
    "strategy = tf.distribute.ParameterServerStrategy(\n",
    "    cluster_resolver,\n",
    "    variable_partitioner=variable_partitioner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WlAQxuMDJ3k9"
   },
   "source": [
    "为了使用 GPU 进行训练，请分配对每个工作进程可见的 GPU。`ParameterServerStrategy` 将使用每个工作进程上的所有可用 GPU，但限制是所有工作进程应该有相同数量的可用  GPU。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QMmBLsf6sEXh"
   },
   "source": [
    "### 变量分片\n",
    "\n",
    "变量分片是指将一个变量拆分为多个较小的变量，这些变量称为*分片*。在访问这些分片时，变量分片可能有助于分配网络负载。这对在多个参数服务器之间分布计算和存储普通变量也很有用，例如，当使用可能不适合单个机器内存的非常大的嵌入时。\n",
    "\n",
    "要启用变量分片，您可以在构造 `ParameterServerStrategy` 对象时传入 `variable_partitioner`。每次创建变量时都会调用 `variable_partitioner`，它预计会返回该变量每个维度上的分片数。提供了一些开箱即用的 `variable_partitioner`，例如 `tf.distribute.experimental.partitioners.MinSizePartitioner`。建议使用基于大小的分区程序（如 `tf.distribute.experimental.partitioners.MinSizePartitioner`）以避免对小变量进行分区，否则可能会对模型训练速度产生负面影响。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1--SxlxtsOb7"
   },
   "source": [
    "当传入 `variable_partitioner`，并且您直接在 `Strategy.scope` 下创建变量时，该变量将成为具有 `variables` 属性的容器类型，该属性提供对分片列表的访问。在大多数情况下，此容器将通过连接所有分片自动转换为张量。因此，它可以用作普通变量。另一方面，一些 TensorFlow 方法（如 `tf.nn.embedding_lookup`）为这种容器类型提供了有效的实现，并且在这些方法中将避免自动连接。\n",
    "\n",
    "有关详细信息，请参阅 `tf.distribute.ParameterServerStrategy` 的 API 文档。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jlOq-O-26O1d"
   },
   "source": [
    "## 使用 `Model.fit` 进行训练\n",
    "\n",
    "<a id=\"training_with_modelfit\"></a>\n",
    "\n",
    "Keras 通过 `Model.fit` 提供了一个易于使用的训练 API，它在后台处理训练循环，具有可覆盖的 `train_step` 的灵活性，以及为 TensorBoard 提供检查点保存或摘要保存等功能的回调。使用 `Model.fit`，只需简单地交换策略对象，即可将相同的训练代码用于其他策略。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oMZ9Cu5J6ZGi"
   },
   "source": [
    "### 输入数据\n",
    "\n",
    "带有 `tf.distribute.ParameterServerStrategy` 的 Keras `Model.fit` 可以采用以下形式的输入数据：`tf.data.Dataset`、`tf.distribute.DistributedDataset`、`tf.keras.utils.experimental.DatasetCreator`，或为了使用方便推荐的 `Dataset` 选项。但是，如果您在使用 `Dataset` 时遇到内存问题，可能需要将 `DatasetCreator` 与可调用的 `dataset_fn` 参数一起使用（有关详细信息，请参阅 `tf.keras.utils.experimental.DatasetCreator` API 文档）。\n",
    "\n",
    "如果将数据集转换为 `tf.data.Dataset`，则应使用 `Dataset.shuffle` 和 `Dataset.repeat`，如下面的代码示例所示。\n",
    "\n",
    "- 带有参数服务器训练的 Keras `Model.fit` 假设每个工作进程接收相同的数据集（除非以不同的方式乱序）。因此，您可以通过调用 `Dataset.shuffle`，确保对数据进行更均匀的迭代。\n",
    "- 由于工作进程不同步，它们可能会在不同的时间完成对数据集的处理。因此，使用参数服务器训练定义周期的最简单方式是使用 `Dataset.repeat`（在不带参数的情况下无限重复数据集）并在 `Model.fit` 调用中指定 `steps_per_epoch` 参数。\n",
    "\n",
    "有关 <code>shuffle</code> 和 `repeat` 的更多详细信息，请参阅 <a>tf.data 指南</a>中的“训练工作流”部分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:14.142470Z",
     "iopub.status.busy": "2023-11-07T23:32:14.141812Z",
     "iopub.status.idle": "2023-11-07T23:32:14.565107Z",
     "shell.execute_reply": "2023-11-07T23:32:14.564076Z"
    },
    "id": "shAo1CCS7wU1"
   },
   "outputs": [],
   "source": [
    "global_batch_size = 64\n",
    "\n",
    "x = tf.random.uniform((10, 10))\n",
    "y = tf.random.uniform((10,))\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()\n",
    "dataset = dataset.batch(global_batch_size)\n",
    "dataset = dataset.prefetch(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v_jhF70K7zON"
   },
   "source": [
    "如果您改为使用 tf.keras.utils.experimental.DatasetCreator 创建数据集，`dataset_fn` 中的代码将在每台工作进程机器的输入设备（通常是 CPU）上调用。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w60PuWrWwBD4"
   },
   "source": [
    "### 模型构建和编译\n",
    "\n",
    "接下来，您将创建 `tf.keras.Model`，这是一个用于演示目的的简单 `tf.keras.models.Sequential` 模型。随后调用 `Model.compile` 以合并组件，例如优化器、指标和其他参数（例如`steps_per_execution`）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:14.569225Z",
     "iopub.status.busy": "2023-11-07T23:32:14.568960Z",
     "iopub.status.idle": "2023-11-07T23:32:14.612735Z",
     "shell.execute_reply": "2023-11-07T23:32:14.611933Z"
    },
    "id": "PhTHUYaD74vT"
   },
   "outputs": [],
   "source": [
    "with strategy.scope():\n",
    "  model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])\n",
    "\n",
    "  model.compile(tf.keras.optimizers.legacy.SGD(), loss=\"mse\", steps_per_execution=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nWb_Ekm377YX"
   },
   "source": [
    "### 回调和训练\n",
    "\n",
    "<a id=\"callbacks-and-training\"> </a>\n",
    "\n",
    "在调用 Keras `Model.fit` 进行实际训练之前，请为常见任务准备任何所需的[回调](https://tensorflow.google.cn/guide/keras/train_and_evaluate)，例如：\n",
    "\n",
    "- `tf.keras.callbacks.ModelCheckpoint`：以特定频率保存模型，例如在每个周期之后。\n",
    "- `tf.keras.callbacks.BackupAndRestore`：如果集群遇到不可用（例如中止或抢占），则通过备份模型和当前周期数来提供容错。然后，您可以在作业失败后重新启动时恢复训练状态，并从中断的周期开始继续训练。\n",
    "- `tf.keras.callbacks.TensorBoard`：定期将模型日志写入可在 TensorBoard 工具中可视化的摘要文件中。\n",
    "\n",
    "注：出于性能考虑，自定义回调在与 `ParameterServerStrategy` 一起使用时不能覆盖批处理级别的回调。请修改您的自定义回调以进行周期级别调用，并将 `steps_per_epoch` 调整为合适的值。此外，当与 `ParameterServerStrategy` 一起使用时，`steps_per_epoch` 是 `Model.fit` 的必需参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:14.616382Z",
     "iopub.status.busy": "2023-11-07T23:32:14.616146Z",
     "iopub.status.idle": "2023-11-07T23:32:22.438338Z",
     "shell.execute_reply": "2023-11-07T23:32:22.437398Z"
    },
    "id": "3ddUvUZk7_wm"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.\n",
      "  warnings.warn(\"To make it possible to preserve tf.data options across \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20/20 - 4s - loss: 0.9051 - 4s/epoch - 216ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20/20 - 1s - loss: 0.7227 - 1s/epoch - 60ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:5 out of the last 5 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7f475878cb80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:5 out of the last 5 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7f475878cb80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:6 out of the last 6 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7f47585c0c10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:6 out of the last 6 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7f47585c0c10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20/20 - 1s - loss: 0.5900 - 578ms/epoch - 29ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20/20 - 1s - loss: 0.4859 - 561ms/epoch - 28ms/step\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/5\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20/20 - 1s - loss: 0.4067 - 568ms/epoch - 28ms/step\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.History at 0x7f49381e95e0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "working_dir = \"/tmp/my_working_dir\"\n",
    "log_dir = os.path.join(working_dir, \"log\")\n",
    "ckpt_filepath = os.path.join(working_dir, \"ckpt\")\n",
    "backup_dir = os.path.join(working_dir, \"backup\")\n",
    "\n",
    "callbacks = [\n",
    "    tf.keras.callbacks.TensorBoard(log_dir=log_dir),\n",
    "    tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_filepath),\n",
    "    tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir),\n",
    "]\n",
    "\n",
    "model.fit(dataset, epochs=5, steps_per_epoch=20, callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uWgP1h2z8B3j"
   },
   "source": [
    "### 直接使用 `ClusterCoordinator`（可选）\n",
    "\n",
    "即使您选择 `Model.fit` 训练路径，您也可以选择实例化一个 `tf.distribute.coordinator.ClusterCoordinator` 对象来调度您希望在工作进程上执行的其他函数。有关更多详细信息和示例，请参阅[使用自定义训练循环进行训练](#training_with_custom_training_loop)部分。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GxypEyIthR0z"
   },
   "source": [
    "## 使用自定义训练循环进行训练\n",
    "\n",
    "<a id=\"training_with_custom_training_loop\"> </a>\n",
    "\n",
    "使用带有 `tf.distribute.Strategy` 的自定义训练循环为定义训练循环提供了极大的灵活性。通过上面定义的 `ParameterServerStrategy`（作为 `strategy`），您将使用 `tf.distribute.coordinator.ClusterCoordinator` 将训练步骤的执行分派给远程工作进程。\n",
    "\n",
    "然后，您将创建一个模型、定义一个数据集并定义一个步骤函数，就像您在训练循环中使用其他 `tf.distribute.Strategy` 一样。您可以在[使用  tf.distribute.Strategy 进行自定义训练](custom_training.ipynb)教程中找到更多详细信息。\n",
    "\n",
    "为确保高效的数据集预提取，请使用以下[向远程工作进程分派训练步骤](#dispatch_training_steps_to_remote_workers)部分中提到的推荐分布式数据集创建 API。此外，请确保在 `worker_fn` 中调用 `Strategy.run` 以充分利用分配给工作进程的 GPU。对于使用或不使用 GPU 的训练，其余步骤相同。\n",
    "\n",
    "我们来按以下步骤创建这些组件：\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4QNkCtV8VivM"
   },
   "source": [
    "### 设置数据\n",
    "\n",
    "首先，编写一个创建数据集的函数。\n",
    "\n",
    "如果您想使用 [Keras 预处理层](https://tensorflow.google.cn/guide/keras/preprocessing_layers)或 [Tensorflow Transform 层](https://tensorflow.google.cn/tfx/tutorials/transform/simple)对数据进行预处理，**请在 `dataset_fn` 之外**和 **`Strategy.scope` 下**创建这些层，就像您对任何其他 Keras 层所做的那样。这是因为 `dataset_fn` 将被封装到一个 `tf.function` 中，然后在每个工作进程上执行以生成数据流水线。\n",
    "\n",
    "如果您不遵循上述过程，创建层可能会创建 Tensorflow 状态，这些状态将从 `tf.function` 提升到协调器。因此，在工作进程上访问它们会导致协调器和工作进程之间重复的 RPC 调用，并导致显著的速度下降。。\n",
    "\n",
    "将这些层放在 `Strategy.scope` 下将改为在所有工作进程上创建它们。然后，您将通过 `tf.data.Dataset.map` 在 `dataset_fn` 中应用转换。有关使用分布式输入进行数据预处理的更多信息，请参阅[分布式输入](input.ipynb)教程中的*数据预处理*。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.442669Z",
     "iopub.status.busy": "2023-11-07T23:32:22.441990Z",
     "iopub.status.idle": "2023-11-07T23:32:22.510920Z",
     "shell.execute_reply": "2023-11-07T23:32:22.510115Z"
    },
    "id": "2GUwATssauus"
   },
   "outputs": [],
   "source": [
    "feature_vocab = [\n",
    "    \"avenger\", \"ironman\", \"batman\", \"hulk\", \"spiderman\", \"kingkong\", \"wonder_woman\"\n",
    "]\n",
    "label_vocab = [\"yes\", \"no\"]\n",
    "\n",
    "with strategy.scope():\n",
    "  feature_lookup_layer = tf.keras.layers.StringLookup(\n",
    "      vocabulary=feature_vocab,\n",
    "      mask_token=None)\n",
    "  label_lookup_layer = tf.keras.layers.StringLookup(\n",
    "      vocabulary=label_vocab,\n",
    "      num_oov_indices=0,\n",
    "      mask_token=None)\n",
    "\n",
    "  raw_feature_input = tf.keras.layers.Input(\n",
    "      shape=(3,),\n",
    "      dtype=tf.string,\n",
    "      name=\"feature\")\n",
    "  feature_id_input = feature_lookup_layer(raw_feature_input)\n",
    "  feature_preprocess_stage = tf.keras.Model(\n",
    "      {\"features\": raw_feature_input},\n",
    "      feature_id_input)\n",
    "\n",
    "  raw_label_input = tf.keras.layers.Input(\n",
    "      shape=(1,),\n",
    "      dtype=tf.string,\n",
    "      name=\"label\")\n",
    "  label_id_input = label_lookup_layer(raw_label_input)\n",
    "\n",
    "  label_preprocess_stage = tf.keras.Model(\n",
    "      {\"label\": raw_label_input},\n",
    "      label_id_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jgp8MX_7OR_A"
   },
   "source": [
    "在数据集中生成演练样本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.515094Z",
     "iopub.status.busy": "2023-11-07T23:32:22.514652Z",
     "iopub.status.idle": "2023-11-07T23:32:22.519953Z",
     "shell.execute_reply": "2023-11-07T23:32:22.519259Z"
    },
    "id": "chIY4fFANaFH"
   },
   "outputs": [],
   "source": [
    "def feature_and_label_gen(num_examples=200):\n",
    "  examples = {\"features\": [], \"label\": []}\n",
    "  for _ in range(num_examples):\n",
    "    features = random.sample(feature_vocab, 3)\n",
    "    label = [\"yes\"] if \"avenger\" in features else [\"no\"]\n",
    "    examples[\"features\"].append(features)\n",
    "    examples[\"label\"].append(label)\n",
    "  return examples\n",
    "\n",
    "examples = feature_and_label_gen()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2AtZBya7OeyZ"
   },
   "source": [
    "然后，创建封装在 `dataset_fn` 中的训练数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.523353Z",
     "iopub.status.busy": "2023-11-07T23:32:22.522781Z",
     "iopub.status.idle": "2023-11-07T23:32:22.527026Z",
     "shell.execute_reply": "2023-11-07T23:32:22.526392Z"
    },
    "id": "Gs0QYRZoNbvw"
   },
   "outputs": [],
   "source": [
    "def dataset_fn(_):\n",
    "  raw_dataset = tf.data.Dataset.from_tensor_slices(examples)\n",
    "\n",
    "  train_dataset = raw_dataset.map(\n",
    "      lambda x: (\n",
    "          {\"features\": feature_preprocess_stage(x[\"features\"])},\n",
    "          label_preprocess_stage(x[\"label\"])\n",
    "      )).shuffle(200).batch(32).repeat()\n",
    "  return train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IT9PQexJiFtB"
   },
   "source": [
    "### 构建模型\n",
    "\n",
    "接下来，创建模型和其他对象。确保在 `Strategy.scope` 下创建所有变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.530166Z",
     "iopub.status.busy": "2023-11-07T23:32:22.529937Z",
     "iopub.status.idle": "2023-11-07T23:32:22.617303Z",
     "shell.execute_reply": "2023-11-07T23:32:22.616556Z"
    },
    "id": "Quxud1uEazeo"
   },
   "outputs": [],
   "source": [
    "# These variables created under the `Strategy.scope` will be placed on parameter\n",
    "# servers in a round-robin fashion.\n",
    "with strategy.scope():\n",
    "  # Create the model. The input needs to be compatible with Keras processing layers.\n",
    "  model_input = tf.keras.layers.Input(\n",
    "      shape=(3,), dtype=tf.int64, name=\"model_input\")\n",
    "\n",
    "  emb_layer = tf.keras.layers.Embedding(\n",
    "      input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=16384)\n",
    "  emb_output = tf.reduce_mean(emb_layer(model_input), axis=1)\n",
    "  dense_output = tf.keras.layers.Dense(units=1, activation=\"sigmoid\")(emb_output)\n",
    "  model = tf.keras.Model({\"features\": model_input}, dense_output)\n",
    "\n",
    "  optimizer = tf.keras.optimizers.legacy.RMSprop(learning_rate=0.1)\n",
    "  accuracy = tf.keras.metrics.Accuracy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iyuxiqCQU50m"
   },
   "source": [
    "我们来确认一下，使用 `FixedShardsPartitioner` 将所有变量拆分成两个分片，并且每个分片都分配给了不同的参数服务器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.621213Z",
     "iopub.status.busy": "2023-11-07T23:32:22.620765Z",
     "iopub.status.idle": "2023-11-07T23:32:22.625620Z",
     "shell.execute_reply": "2023-11-07T23:32:22.624927Z"
    },
    "id": "04r1nO4WVDO1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/job:ps/replica:0/task:1/device:CPU:0\n",
      "/job:ps/replica:0/task:0/device:CPU:0\n"
     ]
    }
   ],
   "source": [
    "assert len(emb_layer.weights) == 2\n",
    "assert emb_layer.weights[0].shape == (4, 16384)\n",
    "assert emb_layer.weights[1].shape == (4, 16384)\n",
    "\n",
    "print(emb_layer.weights[0].device)\n",
    "print(emb_layer.weights[1].device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lWhfXZLRiHyM"
   },
   "source": [
    "### 定义训练步骤\n",
    "\n",
    "第三步，创建封装在 `tf.function` 中的训练步骤："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.629081Z",
     "iopub.status.busy": "2023-11-07T23:32:22.628699Z",
     "iopub.status.idle": "2023-11-07T23:32:22.634815Z",
     "shell.execute_reply": "2023-11-07T23:32:22.634160Z"
    },
    "id": "aNNVo0bFa1K9"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def step_fn(iterator):\n",
    "\n",
    "  def replica_fn(batch_data, labels):\n",
    "    with tf.GradientTape() as tape:\n",
    "      pred = model(batch_data, training=True)\n",
    "      per_example_loss = tf.keras.losses.BinaryCrossentropy(\n",
    "          reduction=tf.keras.losses.Reduction.NONE)(labels, pred)\n",
    "      loss = tf.nn.compute_average_loss(per_example_loss)\n",
    "      gradients = tape.gradient(loss, model.trainable_variables)\n",
    "\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)\n",
    "    accuracy.update_state(labels, actual_pred)\n",
    "    return loss\n",
    "\n",
    "  batch_data, labels = next(iterator)\n",
    "  losses = strategy.run(replica_fn, args=(batch_data, labels))\n",
    "  return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rvrYQUeYiLNy"
   },
   "source": [
    "在上面的训练步骤函数中，在 `step_fn` 中调用`Strategy.run` 和 `Strategy.reduce` 可以支持每个工作进程使用多个 GPU。如果工作进程分配了 GPU，则 `Strategy.run` 会将数据集分布在多个副本上。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GPJ3PV_L2zAY"
   },
   "source": [
    "### 向远程工作进程分派训练步骤\n",
    "\n",
    "<a id=\"dispatch_training_steps_to_remote_workers\"> </a>\n",
    "\n",
    "在 `ParameterServerStrategy` 定义了所有计算之后，您将使用 `tf.distribute.coordinator.ClusterCoordinator` 类创建资源并将训练步骤分发给远程工作进程。\n",
    "\n",
    "我们先创建一个 `ClusterCoordinator` 对象，传入策略对象："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.638171Z",
     "iopub.status.busy": "2023-11-07T23:32:22.637943Z",
     "iopub.status.idle": "2023-11-07T23:32:22.641176Z",
     "shell.execute_reply": "2023-11-07T23:32:22.640542Z"
    },
    "id": "DpcMlH7Pa3DB"
   },
   "outputs": [],
   "source": [
    "coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-xRIgKxciOSe"
   },
   "source": [
    "然后，使用 `ClusterCoordinator.create_per_worker_dataset` API 创建每个工作进程的数据集和迭代器，这会将数据集复制到所有工作进程。在下面的 `per_worker_dataset_fn` 中，建议将 `dataset_fn` 封装到 `strategy.distribute_datasets_from_function` 中，以便无缝且高效地预提取到 GPU。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.644707Z",
     "iopub.status.busy": "2023-11-07T23:32:22.644160Z",
     "iopub.status.idle": "2023-11-07T23:32:22.732796Z",
     "shell.execute_reply": "2023-11-07T23:32:22.731926Z"
    },
    "id": "h9DCvTJTa4Q2"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def per_worker_dataset_fn():\n",
    "  return strategy.distribute_datasets_from_function(dataset_fn)\n",
    "\n",
    "per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)\n",
    "per_worker_iterator = iter(per_worker_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i2pnOx78iRwW"
   },
   "source": [
    "最后一步是使用 `ClusterCoordinator.schedule` 将计算分配给远程工作进程：\n",
    "\n",
    "- `schedule` 方法会将一个 `tf.function` 排入队列，并立即返回一个类似未来的 `RemoteValue`。排队的函数将被分派给后台线程中的远程工作进程，并且 `RemoteValue` 将被异步填充。\n",
    "- `join` 方法 (`ClusterCoordinator.join`) 可用于等待所有调度函数执行完毕。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:22.737079Z",
     "iopub.status.busy": "2023-11-07T23:32:22.736533Z",
     "iopub.status.idle": "2023-11-07T23:32:28.602482Z",
     "shell.execute_reply": "2023-11-07T23:32:28.601502Z"
    },
    "id": "gmPvactfa6Eh"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished epoch 0, accuracy is 0.574675.\n",
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished epoch 1, accuracy is 0.492958.\n",
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished epoch 2, accuracy is 0.927817.\n",
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished epoch 3, accuracy is 1.000000.\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 4\n",
    "steps_per_epoch = 5\n",
    "for i in range(num_epochs):\n",
    "  accuracy.reset_states()\n",
    "  for _ in range(steps_per_epoch):\n",
    "    coordinator.schedule(step_fn, args=(per_worker_iterator,))\n",
    "  # Wait at epoch boundaries.\n",
    "  coordinator.join()\n",
    "  print(\"Finished epoch %d, accuracy is %f.\" % (i, accuracy.result().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WBn-gn-OP3DR"
   },
   "source": [
    "以下是提取 `RemoteValue` 结果的方式："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:28.606498Z",
     "iopub.status.busy": "2023-11-07T23:32:28.605916Z",
     "iopub.status.idle": "2023-11-07T23:32:28.648580Z",
     "shell.execute_reply": "2023-11-07T23:32:28.647604Z"
    },
    "id": "-15a2I_lQDO1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final loss is 0.000000\n"
     ]
    }
   ],
   "source": [
    "loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))\n",
    "print(\"Final loss is %f\" % loss.fetch())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "htY4QKc9iXg9"
   },
   "source": [
    "或者，您可以启动所有步骤并在等待完成时执行某些操作：\n",
    "\n",
    "```python\n",
    "for _ in range(total_steps):\n",
    "  coordinator.schedule(step_fn, args=(per_worker_iterator,))\n",
    "while not coordinator.done():\n",
    "  time.sleep(10)\n",
    "  # Do something like logging metrics or writing checkpoints.\n",
    "```\n",
    "\n",
    "有关此特定示例的完整训练和应用工作流，请查看此[测试](https://github.com/keras-team/keras/blob/master/keras/integration_test/parameter_server_keras_preprocessing_test.py)。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kzNsj2GR3BGs"
   },
   "source": [
    "### 有关数据集创建的更多信息\n",
    "\n",
    "上述代码中的数据集是使用 `ClusterCoordinator.create_per_worker_dataset` API 创建的。它为每个工作进程创建一个数据集并返回一个容器对象。您可以对其调用 `iter` 方法来创建按工作进程的迭代器。按工作进程的迭代器包含每个工作进程一个迭代器，并且工作进程的相应切片将在传递给 `ClusterCoordinator.schedule` 方法的函数的输入参数中被替换，然后在特定工作进程上执行该函数。\n",
    "\n",
    "`ClusterCoordinator.schedule` 方法假设工作进程等价，因此假设不同工作进程上的数据集相同（除了它们可能会以不同方式乱序）。因此，还建议重复数据集，并安排有限数量的步骤，而不是依赖于从数据集中接收 `OutOfRangeError`。\n",
    "\n",
    "另一个重要的注意事项是，`tf.data` 数据集不支持跨任务边界的隐式序列化和反序列化。因此，在传递给 `ClusterCoordinator.create_per_worker_dataset` 的函数内创建整个数据集非常重要。`create_per_worker_dataset` API 也可以直接将 `tf.data.Dataset` 或 `tf.distribute.DistributedDataset` 作为输入。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LcfdI_M83lAM"
   },
   "source": [
    "## 评估\n",
    "\n",
    "使用 `tf.distribute.ParameterServerStrategy` 训练执行评估的两种主要方式是内联评估和边车评估。每种方式都有自己的优点和缺点，如下所述。如果您没有偏好，建议使用内联评估方法。对于使用 `Model.fit` 的用户，`Model.evaluate` 在后台使用内联（分布式）评估。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oiG8EhcY3gA1"
   },
   "source": [
    "### 内联评估\n",
    "\n",
    "在这种方法中，协调器在训练和评估之间交替进行，因此称为*内联评估*。\n",
    "\n",
    "内联评估有几个好处。例如：\n",
    "\n",
    "- 它可以支持单个任务无法容纳的大型评估模型和评估数据集。\n",
    "- 评估结果可用于为下一个周期的训练做出决策（例如是否提前停止训练）。\n",
    "\n",
    "实现内联评估有两种方式：直接评估和分布式评估。\n",
    "\n",
    "- **直接评估**：对于小型模型和评估数据集，协调器可以使用协调器上的评估数据集直接在分布式模型上运行评估："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:28.652996Z",
     "iopub.status.busy": "2023-11-07T23:32:28.652274Z",
     "iopub.status.idle": "2023-11-07T23:32:28.948873Z",
     "shell.execute_reply": "2023-11-07T23:32:28.947892Z"
    },
    "id": "WakiAakoaHVn"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation accuracy: 1.000000\n"
     ]
    }
   ],
   "source": [
    "eval_dataset = tf.data.Dataset.from_tensor_slices(\n",
    "    feature_and_label_gen(num_examples=16)).map(\n",
    "          lambda x: (\n",
    "              {\"features\": feature_preprocess_stage(x[\"features\"])},\n",
    "              label_preprocess_stage(x[\"label\"])\n",
    "          )).batch(8)\n",
    "\n",
    "eval_accuracy = tf.keras.metrics.Accuracy()\n",
    "\n",
    "for batch_data, labels in eval_dataset:\n",
    "  pred = model(batch_data, training=False)\n",
    "  actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)\n",
    "  eval_accuracy.update_state(labels, actual_pred)\n",
    "\n",
    "print(\"Evaluation accuracy: %f\" % eval_accuracy.result())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MKGHbdI7aGoJ"
   },
   "source": [
    "- **分布式评估**：对于无法直接在协调器上运行的大型模型或数据集，协调器任务可以通过 `ClusterCoordinator.schedule`/`ClusterCoordinator.join` 方法将评估任务分配给工作进程："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:32:28.952473Z",
     "iopub.status.busy": "2023-11-07T23:32:28.952183Z",
     "iopub.status.idle": "2023-11-07T23:32:30.138306Z",
     "shell.execute_reply": "2023-11-07T23:32:30.137241Z"
    },
    "id": "XcHNHJpDgEvK"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Waiting for all global closures to be finished.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation accuracy: 1.000000\n"
     ]
    }
   ],
   "source": [
    "with strategy.scope():\n",
    "  # Define the eval metric on parameter servers.\n",
    "  eval_accuracy = tf.keras.metrics.Accuracy()\n",
    "\n",
    "@tf.function\n",
    "def eval_step(iterator):\n",
    "  def replica_fn(batch_data, labels):\n",
    "    pred = model(batch_data, training=False)\n",
    "    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)\n",
    "    eval_accuracy.update_state(labels, actual_pred)\n",
    "  batch_data, labels = next(iterator)\n",
    "  strategy.run(replica_fn, args=(batch_data, labels))\n",
    "\n",
    "def eval_dataset_fn():\n",
    "  return tf.data.Dataset.from_tensor_slices(\n",
    "      feature_and_label_gen(num_examples=16)).map(\n",
    "          lambda x: (\n",
    "              {\"features\": feature_preprocess_stage(x[\"features\"])},\n",
    "              label_preprocess_stage(x[\"label\"])\n",
    "          )).shuffle(16).repeat().batch(8)\n",
    "\n",
    "per_worker_eval_dataset = coordinator.create_per_worker_dataset(eval_dataset_fn)\n",
    "per_worker_eval_iterator = iter(per_worker_eval_dataset)\n",
    "\n",
    "eval_steps_per_epoch = 2\n",
    "for _ in range(eval_steps_per_epoch):\n",
    "  coordinator.schedule(eval_step, args=(per_worker_eval_iterator,))\n",
    "coordinator.join()\n",
    "print(\"Evaluation accuracy: %f\" % eval_accuracy.result())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cKrQktZX5z7a"
   },
   "source": [
    "#### 启用精确一次评估\n",
    "\n",
    "<a id=\"exactly_once_evaluation\"></a>\n",
    "\n",
    "`tf.distribute.coordinator.ClusterCoordinator` 的 `schedule` 和 `join` 方法默认不支持访问保证或精确一次语义。换句话说，在上面的示例中，不能保证数据集中的所有评估示例都将被评估一次；某些可能不会被访问，某些可能会被多次评估。\n",
    "\n",
    "精确一次评估可能更适合减少跨周期评估的方差，并改进通过提前停止、超参数调整或其他方法完成的模型选择。可以通过不同的方式启用精确一次评估：\n",
    "\n",
    "- 使用 `Model.fit/.evaluate` 工作流，可以通过向 `Model.compile` 添加实参来启用。请参阅 `pss_evaluation_shards` 实参的文档。\n",
    "- `tf.data` 服务 API 可以用于在使用 `ParameterServerStrategy` 时为评估提供精确一次性访问（请参阅 `tf.data.experimental.service` API 文档的 *Dynamic Sharding* 部分）。\n",
    "- [边车评估](#sidecar_evaluation)默认提供了精确一次评估，因为评估是在单台机器上进行的。不过，这可能比跨多个工作进程执行评估慢得多。\n",
    "\n",
    "第一个选项，使用 `Model.compile`，是大多数用户建议的解决方案。\n",
    "\n",
    "精确一次评估存在一些限制：\n",
    "\n",
    "- 不支持编写具有精确一次访问保证的自定义分布式评估循环。如果您需要此功能，请在 GitHub 上提交议题。\n",
    "- 无法自动处理使用 `Layer.add_metric` API 的指标的计算。这些应该从评估中排除，或者重新制作成 `Metric` 对象。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H40X-9Gs3i7_"
   },
   "source": [
    "### 边车评估\n",
    "\n",
    "<a id=\"sidecar_evaluation\"></a>\n",
    "\n",
    "在 `tf.distribute.ParameterServerStrategy` 训练中定义和运行评估循环的另一种方法称为*边车评估*，您可以在其中创建一个专用的评估器任务，重复读取检查点并在最新的检查点上运行评估（有关检查点的更多详细信息，请参阅[此指南](../../guide/checkpoint.ipynb)）。协调程序任务和工作进程任务不花费任何时间进行评估，因此对于固定次数的迭代，整体训练时间应该比使用其他评估方法更短。但是，它需要额外的评估器任务和定期检查点来触发评估。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HonyjnXK9-ys"
   },
   "source": [
    "要为边车评估编写评估循环，您有两种选择：\n",
    "\n",
    "1. 使用 `tf.keras.utils.SidecarEvaluator` API。\n",
    "2. 创建自定义评估循环。\n",
    "\n",
    "有关选项 1 的更多详细信息，请参阅 `tf.keras.utils.SidecarEvaluator` API 文档。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U_c0EiwB88OG"
   },
   "source": [
    "边车评估仅支持单个任务。这表示：\n",
    "\n",
    "- 保证每个样本都被评估一次。如果评估器被抢占或重新启动，它只会从最新的检查点重新启动评估循环，并且在重新启动之前完成的部分评估进度会被丢弃。\n",
    "\n",
    "- 但是，对单个任务运行评估意味着完整的评估可能需要很长时间。\n",
    "\n",
    "- 如果模型的大小太大而无法放入评估器的内存中，则单个边车评估不适用。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VNJoWVc797B1"
   },
   "source": [
    "另一个注意事项是，`tf.keras.utils.SidecarEvaluator` 实现和下面的自定义评估循环可能会跳过一些检查点，因为它总是会选择可用的最新检查点，并且在一个评估周期中，可以从训练集群产生多个检查点。您可以编写一个自定义评估循环来评估每个检查点，但本教程不涉及此内容。另一方面，如果生成检查点的频率低于运行评估所需的时间，它可能会闲置。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G5jopxBd85Ji"
   },
   "source": [
    "自定义评估循环提供了对细节的更多控制，例如选择要评估的检查点，或提供与评估一起运行的任何附加逻辑。以下是一个可能的自定义边车评估循环：\n",
    "\n",
    "```python\n",
    "checkpoint_dir = ...\n",
    "eval_model = ...\n",
    "eval_data = ...\n",
    "checkpoint = tf.train.Checkpoint(model=eval_model)\n",
    "\n",
    "for latest_checkpoint in tf.train.checkpoints_iterator(\n",
    "    checkpoint_dir):\n",
    "  try:\n",
    "    checkpoint.restore(latest_checkpoint).expect_partial()\n",
    "  except (tf.errors.OpError,) as e:\n",
    "    # checkpoint may be deleted by training when it is about to read it.\n",
    "    continue\n",
    "\n",
    "  # Optionally add callbacks to write summaries.\n",
    "  eval_model.evaluate(eval_data)\n",
    "\n",
    "  # Evaluation finishes when it has evaluated the last epoch.\n",
    "  if latest_checkpoint.endswith('-{}'.format(train_epochs)):\n",
    "    break\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9TkNbtpPhFRQ"
   },
   "source": [
    "## 现实世界中的集群\n",
    "\n",
    "<a id=\"real_clusters\"></a>\n",
    "\n",
    "注：此部分不是运行此页面中的教程代码所必需的内容。\n",
    "\n",
    "在真实的生产环境中，您将在不同的机器上运行不同进程中的所有任务。在每个任务上配置集群信息的最简单方法是设置 `\"TF_CONFIG\"` 环境变量并使用 `tf.distribute.cluster_resolver.TFConfigClusterResolver` 来解析 `\"TF_CONFIG\"`。\n",
    "\n",
    "有关 `\"TF_CONFIG\"` 环境变量的一般描述，请参阅[分布式训练](../../guide/distributed_training.ipynb)指南中的“设置 `TF_CONFIG` 环境变量”部分。\n",
    "\n",
    "如果您使用 Kubernetes 或其他配置模板开始您的训练任务，则这些模板可能已经为您设置了 `“TF_CONFIG\"`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n7AK9SJGt3tQ"
   },
   "source": [
    "### 设置 `\"TF_CONFIG\"` 环境变量\n",
    "\n",
    "假设您有 3 个工作进程和两个参数服务器。那么工作进程 1 的 `\"TF_CONFIG\"`可以是：\n",
    "\n",
    "```python\n",
    "os.environ[\"TF_CONFIG\"] = json.dumps({\n",
    "    \"cluster\": {\n",
    "        \"worker\": [\"host1:port\", \"host2:port\", \"host3:port\"],\n",
    "        \"ps\": [\"host4:port\", \"host5:port\"],\n",
    "        \"chief\": [\"host6:port\"]\n",
    "    },\n",
    "    \"task\": {\"type\": \"worker\", \"index\": 1}\n",
    "})\n",
    "```\n",
    "\n",
    "评估器的 `\"TF_CONFIG\"` 可以是：\n",
    "\n",
    "```python\n",
    "os.environ[\"TF_CONFIG\"] = json.dumps({\n",
    "    \"cluster\": {\n",
    "        \"evaluator\": [\"host7:port\"]\n",
    "    },\n",
    "    \"task\": {\"type\": \"evaluator\", \"index\": 0}\n",
    "})\n",
    "```\n",
    "\n",
    "上述评估器的 `\"TF_CONFIG\"` 字符串中的 `\"cluster\"` 部分为可选。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fZRjMS0pt1LM"
   },
   "source": [
    "### 如果您对所有任务使用相同的二进制文件\n",
    "\n",
    "如果您更喜欢使用单个二进制文件运行所有这些任务，则需要让您的程序在一开始就分支到不同的角色：\n",
    "\n",
    "```python\n",
    "cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()\n",
    "if cluster_resolver.task_type in (\"worker\", \"ps\"):\n",
    "  # Start a TensorFlow server and wait.\n",
    "elif cluster_resolver.task_type == \"evaluator\":\n",
    "  # Run sidecar evaluation\n",
    "else:\n",
    "  # Run the coordinator.\n",
    "```\n",
    "\n",
    "以下代码会启动 TensorFlow 服务器并等待，这对 `\"worker\"` 和 `\"ps\"` 角色很有用：\n",
    "\n",
    "```python\n",
    "# Set the environment variable to allow reporting worker and ps failure to the\n",
    "# coordinator. This is a workaround and won't be necessary in the future.\n",
    "os.environ[\"GRPC_FAIL_FAST\"] = \"use_caller\"\n",
    "\n",
    "server = tf.distribute.Server(\n",
    "    cluster_resolver.cluster_spec(),\n",
    "    job_name=cluster_resolver.task_type,\n",
    "    task_index=cluster_resolver.task_id,\n",
    "    protocol=cluster_resolver.rpc_layer or \"grpc\",\n",
    "    start=True)\n",
    "server.join()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZWdYfK593eOL"
   },
   "source": [
    "## 处理任务失败"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Bl9eK5r13cOv"
   },
   "source": [
    "### 工作进程故障\n",
    "\n",
    "`tf.distribute.coordinator.ClusterCoordinator` 自定义训练循环和 `Model.fit` 方法都为工作进程故障提供了内置的容错能力。工作进程恢复后，`ClusterCoordinator` 会在工作进程上调用数据集重新创建。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aP0OHZ1-Ne-B"
   },
   "source": [
    "### 参数服务器或协调器故障\n",
    "\n",
    "但是，当协调器看到参数服务器错误时，它会立即引发 `UnavailableError` 或 `AbortedError`。在这种情况下，您可以重新启动协调器。协调器本身也可能变得不可用。因此，为了不丢失训练进度，建议使用以下工具："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f7m7Itoz8lsI"
   },
   "source": [
    "- 对于 `Model.fit`，您应该使用 `BackupAndRestore` 回调，它会自动处理进度保存和恢复。有关示例，请参阅上面的[回调和训练](#callbacks-and-training)部分。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-XlLyJp53Z8A"
   },
   "source": [
    "- 对于自定义训练循环，您应该定期检查模型变量并在训练开始之前从检查点加载模型变量（如果有的话）。如果优化器设置了检查点，则可以从 `optimizer.iterations` 中大致推断出训练进度：\n",
    "\n",
    "```python\n",
    "checkpoint_manager = tf.train.CheckpointManager(\n",
    "    tf.train.Checkpoint(model=model, optimizer=optimizer),\n",
    "    checkpoint_dir,\n",
    "    max_to_keep=3)\n",
    "if checkpoint_manager.latest_checkpoint:\n",
    "  checkpoint = checkpoint_manager.checkpoint\n",
    "  checkpoint.restore(\n",
    "      checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()\n",
    "\n",
    "global_steps = int(optimizer.iterations.numpy())\n",
    "starting_epoch = global_steps // steps_per_epoch\n",
    "\n",
    "for _ in range(starting_epoch, num_epochs):\n",
    "  for _ in range(steps_per_epoch):\n",
    "    coordinator.schedule(step_fn, args=(per_worker_iterator,))\n",
    "  coordinator.join()\n",
    "  checkpoint_manager.save()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PlN1P7C53XK9"
   },
   "source": [
    "### 提取 `RemoteValue`\n",
    "\n",
    "如果函数执行成功，则保证提取 `RemoteValue` 成功。这是因为当前函数执行后，返回值会立即复制到协调器。如果在复制过程中出现任何工作进程故障，该函数将在另一个可用的工作进程上重试。因此，如果要优化性能，可以在没有返回值的情况下调度函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iZcR_xNZ3UdU"
   },
   "source": [
    "## 错误报告\n",
    "\n",
    "协调器发现来自参数服务器的错误（如 `UnavailableError`）或其他应用错误（如来自 `tf.debugging.check_numerics` 的 `InvalidArgument`）时，它将在引发错误之前取消所有挂起和排队的函数。提取它们对应的 `RemoteValue` 将引发 `CancelledError`。\n",
    "\n",
    "引发错误后，协调器将不会引发相同的错误或来自被取消函数的任何错误。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QfhbXH-j3NVw"
   },
   "source": [
    "## 性能改进\n",
    "\n",
    "当您使用 `tf.distribute.ParameterServerStrategy` 和 `tf.distribute.coordinator.ClusterCoordinator` 进行训练时，由于几个原因，您可能会遇到性能问题。\n",
    "\n",
    "一个常见的原因是参数服务器的负载不平衡，一些负载很重的参数服务器已经达到容量。也可能有多个根本原因。缓解此问题的一些简单方法是：\n",
    "\n",
    "1. 在构造 `ParameterServerStrategy` 时通过指定 `variable_partitioner` 对大型模型变量进行分片。\n",
    "2. 尽可能避免在一个步骤中创建所有参数服务器都需要的热点变量。例如，在优化器中使用恒定的学习率或子类 `tf.keras.optimizers.schedules.LearningRateSchedule`，因为默认行为是学习率将成为放置在特定参数服务器上的变量，并在每个步骤中由所有其他参数服务器请求。\n",
    "3. 在将大型词汇表传递给 Keras 预处理层之前，对其进行乱序。\n",
    "\n",
    "性能问题的另一个可能原因是协调器。`schedule`/`join` 的实现基于 Python，因此可能会有线程开销。此外，协调器和工作进程之间的延迟可能很大。如果是这种情况，请按以下步骤进行操作：\n",
    "\n",
    "- 对于 `Model.fit`，您可以将 `Model.compile` 中提供的 `steps_per_execution` 参数设置为大于 1 的值。\n",
    "\n",
    "- 对于自定义训练循环，您可以将多个步骤打包到单个 `tf.function` 中：\n",
    "\n",
    "```python\n",
    "steps_per_invocation = 10\n",
    "\n",
    "@tf.function\n",
    "def step_fn(iterator):\n",
    "  for _ in range(steps_per_invocation):\n",
    "    features, labels = next(iterator)\n",
    "    def replica_fn(features, labels):\n",
    "      ...\n",
    "\n",
    "    strategy.run(replica_fn, args=(features, labels))\n",
    "```\n",
    "\n",
    "随着库的进一步优化，希望大多数用户未来不必手动打包步骤。\n",
    "\n",
    "此外，提高性能的一个小技巧是调度没有返回值的函数，如上面[处理任务失败部分](#handling_task_failure)所述。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "chu5F7M_JmVk"
   },
   "source": [
    "## 已知限制\n",
    "\n",
    "<a id=\"known_limitations\"> </a>\n",
    "\n",
    "大多数已知限制已在上述部分中进行了介绍。此部分进行了总结。\n",
    "\n",
    "### `ParameterServerStrategy` 通用\n",
    "\n",
    "- `os.environment[\"grpc_fail_fast\"]=\"use_caller\"` 在包括协调器在内的每个任务上都需要，以使容错正常工作。\n",
    "- 不支持同步参数服务器训练。\n",
    "- 通常需要将多个步骤打包到一个函数中以实现最佳性能。\n",
    "- 不支持通过 `tf.saved_model.load` 加载包含分片变量的 saved_model。请注意，使用 TensorFlow Serving 加载此类 saved_model 应该可行（有关详细信息，请参阅[应用教程](https://tensorflow.google.cn/tfx/tutorials/serving/rest_simple)）。\n",
    "- 不支持在不重启协调器任务的情况下从参数服务器故障中恢复。\n",
    "- `tf.lookup.StaticHashTable` 的创建，通常由一些 Keras 预处理层使用，例如 `tf.keras.layers.IntegerLookup`、`tf.keras.layers.StringLookup` 和 `tf.keras.layers.TextVectorization`，应该放在 `Strategy.scope` 下。否则，资源将被放置在协调器上，并且从工作进程到协调器查找 RPC 会产生性能影响。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2MKBF0RPSvzB"
   },
   "source": [
    "### `Model.fit` 细节\n",
    "\n",
    "- 在 `Model.fit` 中需要 `steps_per_epoch` 参数。您可以选择一个在周期内提供适当间隔的值。\n",
    "- 出于性能原因，`ParameterServerStrategy` 不支持具有批处理级别调用的自定义回调。您应该使用适当选取的 `steps_per_epoch` 将这些调用转换为周期级别的调用，以便每隔 `steps_per_epoch` 个步骤调用它们一次。内置回调不受影响：它们的批处理级别调用已被修改为高性能。正在计划支持对 `ParameterServerStrategy` 的批处理级别调用。\n",
    "- 出于同样的原因，与其他策略不同，进度条和指标仅在周期边界处记录。\n",
    "- 不支持 `run_eagerly`。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wvY-mg35Sx5L"
   },
   "source": [
    "### 自定义训练循环细节\n",
    "\n",
    "- 一般而言，`ClusterCoordinator.schedule` 不支持数据集的访问保证，但可以通过 `Model.fit/.evaluate` 实现评估的访问保证。请参阅 [启用精确一次评估](#exactly_once_evaluation)。\n",
    "- 当 `ClusterCoordinator.create_per_worker_dataset` 与可调用对象作为输入一起使用时，必须在传递给它的函数内创建整个数据集。\n",
    "- `tf.data.Options` 会在由 `ClusterCoordinator.create_per_worker_dataset` 创建的数据集中被忽略。"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "parameter_server_training.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
